
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.data.util import nest

from niftynet.layer.base_layer import Layer
from niftynet.io.image_reader import param_to_dict


class CSVReader(Layer):
    '''
    Reader for per-subject csv inputs.

    For every requested task a csv file is read (first column: subject id),
    restricted to the subjects of the current file list and converted to
    one array per row; ``layer_op`` then selects rows by position or by
    subject id and pads the selected data to 6-D arrays.
    '''
    def __init__(self, names=None):
        # Task keywords this reader should serve; filtered in ``initialise``
        # to the ones present in task_param.
        self.names = names
        self._paths = None
        self._labels = None
        self._df = None
        self.label_names = None
        self.n_samples_per_id = None
        self.dims = None
        self.data_param = None
        self._dims = None
        self._indexable_output = {}
        self.file_list = None
        self.subject_ids = None
        self._shapes = {}
        self._input_sources = None
        # Per-task lookup tables, all keyed by task name (filled in
        # ``initialise``).
        self.df_by_task = {}
        self.valid_by_task = {}
        self.pad_by_task = {}
        self.dims_by_task = {}
        self.type_by_task = {}
        self._dtypes = {}
        self.task_param = None
        super(CSVReader, self).__init__(name='csv_reader')

    def initialise(self, data_param, task_param=None, file_list=None,
                   sample_per_volume=1):
        """
        Parse the csv file of every requested task and build the lookup
        tables used by ``layer_op``.

        ``data_param`` maps each task name to its csv settings, e.g.::

             data_param = {'label': {'csv_data_file': 'path/to/some_data.csv',
                                     'to_ohe': False}}

        Each csv is read with ``header=None`` (a header line is treated as a
        data row) and its first column is the subject id. Three layouts are
        supported:

        - one value column, ``to_ohe=True``: each value is one-hot encoded
          (a float32 vector of length n_classes per row);
        - one value column, ``to_ohe=False``: each value is replaced by its
          position among the sorted distinct values (a float32 scalar);
        - several value columns, ``to_ohe=False``: the row values are used
          as a feature vector.

        :param data_param: dictionary of input sections
        :param task_param: Namespace object or dictionary mapping task names
            to input sources; defaults to ``{name: (name,)}`` for every
            section of ``data_param``
        :param file_list: a dataframe generated by ImagePartitioner with a
            ``subject_id`` column; only those subjects are kept, so that the
            reader only loads the rows of the current phase (e.g. the
            training or inference part of a cross-validation split)
        :param sample_per_volume: number of samples taken per volume; the
            csv output is tiled that many times
        :return: self
        :raises AssertionError: if ``names`` or ``file_list`` is None
        :raises ValueError: if none of ``names`` is a task keyword
        """
        assert self.names is not None
        data_param = param_to_dict(data_param)
        self.n_samples_per_id = sample_per_volume
        tf.logging.info('CSVReader data_param: %s', data_param)
        if not task_param:
            task_param = {mod: (mod,) for mod in list(data_param)}
        try:
            if not isinstance(task_param, dict):
                task_param = vars(task_param)
        except (TypeError, ValueError):
            # vars() raises TypeError for objects without a __dict__.
            tf.logging.fatal(
                "To concatenate multiple input data arrays,\n"
                "task_param should be a dictionary in the form:\n"
                "{'new_modality_name': ['modality_1', 'modality_2',...]}.")
            raise
        self.task_param = task_param
        valid_names = [name for name in self.names
                       if self.task_param.get(name, None)]
        if not valid_names:
            tf.logging.fatal("CSVReader requires task input keywords %s, but "
                             "not exist in the config file.\n"
                             "Available task keywords: %s",
                             self.names, list(self.task_param))
            raise ValueError
        self.names = valid_names
        self.data_param = data_param
        self._dims = None
        self._indexable_output = {}
        # Checked before use: the subject ids are read from file_list below.
        assert file_list is not None
        self.file_list = file_list
        self.subject_ids = self.file_list['subject_id'].values

        self._input_sources = {name: self.task_param.get(name)
                               for name in self.names}
        self.df_by_task = {}
        self.valid_by_task = {}
        self.pad_by_task = {}
        self.dims_by_task = {}
        self.type_by_task = {}

        n_subjects = len(set(self.subject_ids))
        for name in valid_names:
            df_fin, indexable_output, dims = self._parse_csv(
                path_to_csv=data_param[name].get('csv_data_file', None),
                to_ohe=data_param[name].get('to_ohe', False))
            n_rows = df_fin.shape[0]
            self.df_by_task[name] = df_fin
            self.dims_by_task[name] = dims
            self._indexable_output[name] = indexable_output
            # -1 means the row has not been checked yet
            self.valid_by_task[name] = -1 * np.ones([n_rows])
            self.pad_by_task[name] = np.zeros([n_rows, 2 * dims])
            # 'multi': more csv rows than distinct subjects in the file list
            self.type_by_task[name] = \
                'multi' if n_rows > n_subjects else 'mono'
        return self

    def _parse_csv(self, path_to_csv, to_ohe):
        '''
        Read one csv file and convert its rows to arrays.

        :param path_to_csv: path of the csv file (first column: subject id)
        :param to_ohe: whether a single value column is one-hot encoded
        :return: (dataframe restricted to ``self.subject_ids``, list with one
            array per row, output dimension)
        :raises ValueError: if a subject id of the file list is missing from
            the csv, or the layout / ``to_ohe`` combination is unsupported
        '''
        tf.logging.warning('This method will read your entire csv into memory')
        df_init = pd.read_csv(path_to_csv, index_col=0, header=None)
        # Subject ids are compared as strings.
        df_init.index = df_init.index.map(str)

        wanted = set(self.subject_ids)  # built once, not once per row
        present = set(df_init.index)
        if present != wanted:
            # Usually the csv covers more subjects than the current split.
            tf.logging.info('csv file %s: dropping subjects that are not in '
                            'the current file list', path_to_csv)
            df_fin = df_init.drop(
                index=[s for s in present if s not in wanted])
            missing = wanted - set(df_fin.index)
            if missing:
                msg = ('csv file provided at: {} does not have all the '
                       'subject_ids (missing: {})'.format(
                           path_to_csv, sorted(str(s) for s in missing)))
                tf.logging.fatal(msg)
                raise ValueError(msg)
        else:
            df_fin = df_init.copy()

        single_column = len(df_fin.columns) == 1
        if to_ohe and single_column:
            # header=None + index_col=0: the only value column is labelled 1
            dims = len(list(df_fin[1].unique()))
            indexable_output = self.to_ohe(df_fin[1].values, dims)
            return df_fin, indexable_output, dims
        if not to_ohe and single_column:
            dims = 1
            indexable_output = self.to_categorical(
                df_fin[1].values, np.sort(df_fin[1].unique()))
            return df_fin, indexable_output, dims
        if not to_ohe:
            dims = len(df_fin.columns)
            indexable_output = list(df_fin.values)
            return df_fin, indexable_output, dims
        msg = 'Unrecognised input format for {}'.format(path_to_csv)
        tf.logging.fatal(msg)
        raise ValueError(msg)

    @staticmethod
    def to_ohe(labels, _dims):
        '''
        One-hot encode labels.

        Classes are ordered by sorting the distinct labels, so the encoding
        is identical in every process (iteration order of a ``set`` of
        strings depends on hash randomisation).

        :param labels: labels to encode
        :param _dims: length of each one-hot vector; must be at least the
            number of distinct labels
        :return: list of float32 vectors of length ``_dims``, one per label
        '''
        label_names = sorted(set(labels))
        eye = np.eye(_dims).astype(np.float32)
        # copy() so vectors returned for equal labels do not share memory
        return [eye[label_names.index(label)].copy() for label in labels]

    @staticmethod
    def to_categorical(labels, label_names):
        '''
        Replace every label by its position in ``label_names``.

        :param labels: labels to encode
        :param label_names: ordered label vocabulary containing every label
            (``list.index`` raises ValueError otherwise)
        :return: list of 0-d float32 arrays, one per label
        '''
        names = list(label_names)  # converted once, not once per label
        return [np.array(names.index(label)).astype(np.float32)
                for label in labels]

    def _candidate_indices(self, name, subject_id, reject):
        '''
        Positional row indices of ``self.df_by_task[name]`` for a subject.

        :param name: task name
        :param subject_id: subject id matched against the csv index
        :param reject: if True, rows whose ``valid_by_task`` flag is 0 are
            dropped (rows still flagged -1, i.e. unchecked, are kept)
        :return: list of positional indices, possibly empty
        '''
        mask = np.asarray(self.df_by_task[name].index == subject_id)
        positions = np.where(mask)[0]
        if reject:
            flags = self.valid_by_task[name][positions]
            positions = positions[np.abs(flags) > 0]
        return list(positions)

    def _subject_at(self, name, idx):
        '''
        Subject id of the first row selected by ``idx`` (int or sequence of
        positional indices) in task ``name``.
        '''
        position = np.asarray(idx).ravel()[0]
        return self.df_by_task[name].index[position]

    def layer_op(self, idx=None, subject_id=None, mode='single', reject=True):
        '''
        Select csv rows and return them formatted as 6-D arrays.

        Selection rules:

        - ``idx`` is a dict: ``idx[name]`` is used for each task;
        - ``idx`` is anything else but None (e.g. an int): the same ``idx``
          is used for every task;
        - ``idx`` is None and ``subject_id`` is given: the candidate rows of
          that subject are looked up per task (an empty selection if none
          is left);
        - both are None: a random row of the first task picks the subject,
          and every task must then have at least one candidate row.

        With ``mode='single'`` one candidate row is drawn at random per task;
        any other mode returns all candidate rows.

        :param idx: positional row index, or dict of them keyed by task
        :param subject_id: subject id
        :param mode: chosen mode ('single', otherwise all rows)
        :param reject: if True, rows whose ``valid_by_task`` flag is 0 are
            not candidates
        :return: (dict of selected indices, dict of formatted outputs,
            subject_id)
        '''
        if idx is None:
            require_candidates = subject_id is None
            if subject_id is None:
                first = self.names[0]
                row = np.random.randint(self.df_by_task[first].shape[0])
                subject_id = self.df_by_task[first].index[row]
                tf.logging.debug('new subject id is %s', subject_id)
            idx_dict = {}
            for name in self.names:
                candidates = self._candidate_indices(name, subject_id, reject)
                if require_candidates:
                    assert candidates, 'no valid index for subject ' \
                        '%s and field %s' % (subject_id, name)
                if mode == 'single':
                    idx_dict[name] = \
                        random.choice(candidates) if candidates else []
                else:
                    idx_dict[name] = candidates
                tf.logging.debug('selected %s for %s', idx_dict[name], name)
        elif not isinstance(idx, dict):
            idx_dict = {name: idx for name in self.names}
            if subject_id is None:
                subject_id = self._subject_at(self.names[0], idx)
        else:
            idx_dict = {}
            for name in self.names:
                idx_dict[name] = idx[name]
                assert np.asarray(idx[name]).size > 0, \
                    'no valid index for %s' % name
                if subject_id is None:
                    subject_id = self._subject_at(name, idx[name])

        if self._indexable_output is not None:
            output_dict = {
                k: self.apply_niftynet_format_to_data(self.tile_nsamples(
                    np.asarray(self._indexable_output[k])[v]))
                for k, v in idx_dict.items()}
            return idx_dict, output_dict, subject_id
        raise Exception('Invalid mode')

    def tile_nsamples(self, data):
        '''
        Repeat csv data once per sample extracted from the same volume.

        When ``n_samples_per_id > 1`` a length-1 axis is inserted at
        position 1 (position 0 for 0-d input) and the result is repeated
        along a new leading axis of length ``n_samples_per_id``; otherwise
        ``data`` is returned unchanged.

        :param data: csv data to tile for all samples extracted in the volume
        :return: tiled data
        '''
        if self.n_samples_per_id > 1:
            # expand_dims(x, 1) raises for 0-d x in NumPy >= 1.18
            axis = min(1, np.asarray(data).ndim)
            data = np.expand_dims(data, axis)
            # One more rep than data.ndim prepends the sample axis.
            reps = [self.n_samples_per_id] + [1] * data.ndim
            data = np.tile(data, reps)
            tf.logging.debug('tiling done %s', data.shape)
        return data

    @property
    def shapes(self):
        """
        :return: dict of label shape and label location shape:
            ``{name: (n, dims, 1, 1, 1, 1), name + '_location': (n, 7)}``
            with ``n = n_samples_per_id``
        """
        n_samples = self.n_samples_per_id
        for name in self.names:
            self._shapes.update({
                name: (n_samples, self.dims_by_task[name], 1, 1, 1, 1),
                name + '_location': (n_samples, 7)})
        return self._shapes

    @property
    def tf_dtypes(self):
        """
        Infer input data dtypes in TF
        """
        for name in self.names:
            self._dtypes.update({name: tf.float32,
                                 name + '_location': tf.int32})
        return self._dtypes

    @property
    def tf_shapes(self):
        """
        :return: a dictionary of sampler output tensor shapes
        """
        output_shapes = nest.map_structure_up_to(
            self.tf_dtypes, tf.TensorShape, self.shapes)
        return output_shapes

    @staticmethod
    def apply_niftynet_format_to_data(data):
        '''
        Pad data with trailing length-1 axes up to 6 dimensions.

        1-D input first gets a leading length-1 axis, so a vector of length
        d becomes shape (1, d, 1, 1, 1, 1).

        :param data: data to expand
        :return: expanded data (at least 6-D)
        '''
        if len(data.shape) == 1:
            data = np.expand_dims(data, 0)
        while len(data.shape) < 6:
            data = np.expand_dims(data, -1)
        return data

